
fix(mla): widen page index to int64_t to avoid 32-bit overflow #3136

Open
Tracin wants to merge 2 commits into flashinfer-ai:main from Tracin:fix_mla

Conversation

@Tracin

@Tracin Tracin commented Apr 21, 2026

📌 Description

In the MLA decode/prefill KV load path, indices[q] * ckv_stride_page was computed in 32-bit because IdType is int32_t and *_stride_page is uint32_t; the product wraps modulo 2^32 before any widening to int64_t (Hopper) or pointer arithmetic (FA2). For large page pools (e.g. page_idx ~1M with page_size=32, kv_lora_rank=512, stride=16384) the true product exceeds 2^32 and the kernel reads the wrong page, producing all-zero outputs. Cast the selected page index to int64_t at all three sites (mla.cuh NUM_MMA_KV==1 and !=1 branches, and mla_hopper.cuh prefetch_offset) so the multiply executes in 64-bit.
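The wrap can be reproduced on the host without a GPU. The sketch below is a hypothetical standalone model of the kernel's address arithmetic, not FlashInfer code: it emulates the C uint32 multiply by reducing modulo 2^32 and compares it to the widened product, using the stride and page-index values quoted above.

```python
# Host-side model of the 32-bit vs 64-bit page-offset computation described
# in the PR. stride_page = page_size * kv_lora_rank = 32 * 512 = 16384 and
# page_idx ~1M come from the description; everything else is illustrative.

def offset_32bit(page_idx: int, stride_page: int) -> int:
    """Product as a C uint32 multiply: wraps modulo 2^32 before any widening."""
    return (page_idx * stride_page) % (1 << 32)

def offset_64bit(page_idx: int, stride_page: int) -> int:
    """Product after casting the index to int64_t first: no wrap."""
    return page_idx * stride_page  # Python ints are arbitrary precision

stride_page = 32 * 512   # ckv_stride_page in elements
page_idx = 1_000_000     # large page pool, as in the bug report

wrong = offset_32bit(page_idx, stride_page)  # 3499098112
right = offset_64bit(page_idx, stride_page)  # 16384000000, i.e. > 2^32
assert wrong != right                        # the high bits were lost
assert (right - wrong) % (1 << 32) == 0      # difference is whole 2^32 multiples
```

The wrapped value still lands inside the allocation, which is why the bug corrupts output silently instead of faulting.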

🔍 Related Issues

#3130

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes
    • Fixed potential integer overflow in page-index arithmetic that could cause incorrect memory page selection for very large attention caches.
  • Tests
    • Added a regression test that exercises page-index overflow scenarios to ensure correct behavior on large KV caches.

…address computation

In the MLA decode/prefill KV load path, `indices[q] * ckv_stride_page`
was computed in 32-bit because `IdType` is `int32_t` and `*_stride_page`
is `uint32_t`; the product wraps modulo 2^32 before any widening to
`int64_t` (Hopper) or pointer arithmetic (FA2). For large page pools
(e.g. page_idx ~1M with page_size=32, kv_lora_rank=512, stride=16384)
the true product exceeds 2^32 and the kernel reads the wrong page,
producing all-zero outputs. Cast the selected page index to `int64_t`
at all three sites (mla.cuh NUM_MMA_KV==1 and !=1 branches, and
mla_hopper.cuh prefetch_offset) so the multiply executes in 64-bit.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@coderabbitai
Contributor

coderabbitai Bot commented Apr 21, 2026

📝 Walkthrough

Walkthrough

Widened KV page-index arithmetic to 64-bit in two CUDA attention sources to prevent uint32 overflow during stride multiplications; added a regression test that exercises page-index overflow cases for MLA decode kernels.

Changes

  • MLA KV Loading (include/flashinfer/attention/mla.cuh): cast the computed KV page index to int64_t before multiplying by ckv_stride_page/kpe_stride_page in load_kv (both the single- and multi-MMA paths).
  • MLA Hopper Prefetching (include/flashinfer/attention/mla_hopper.cuh): cast the page-offset multiplicand to int64_t before multiplying by *_stride_page when computing ckv_offset/kpe_offset in prefetch_offset.
  • Tests (tests/attention/test_mla_decode_kernel.py): added regression test test_mla_page_index_uint32_overflow_regression, which constructs a large KV cache and verifies correct page selection when indices would overflow uint32.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • yzh119
  • saltyminty
  • bkryu
  • qsang-nv
  • nv-yunzheq

Poem

🐰 I hopped through pages, counted wide,

From thirty-two I learned to bide,
Now sixty-four keeps strides in line,
No overflow — the cache is fine,
Kernel hums and rabbits glide.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (4 passed)

  • Title check: the title clearly and concisely describes the main fix, widening the page index to int64_t to prevent 32-bit overflow in MLA operations.
  • Description check: the description includes all required sections from the template: detailed explanation of the bug and fix, related issue link, pre-commit completion confirmation, and notes about testing status.
  • Linked Issues check: skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check: skipped because no linked issues were found for this pull request.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request addresses potential 32-bit integer overflows in KV cache offset calculations within mla.cuh and mla_hopper.cuh by casting page indices to int64_t. Feedback suggests that the entire offset calculation should be promoted to 64-bit to prevent overflows in subsequent additions and to improve future-proofing.

Comment thread include/flashinfer/attention/mla.cuh
@qsang-nv
Collaborator

Nice fix — cast is in the right place and all three call sites are covered.

One suggestion: add a minimal regression test that forces page_idx * stride_page > 2^32. This is a silent-output-corruption bug (wrong output, no crash), so without a guard test, a future refactor of IdType or *_stride_page types could easily reintroduce it.

You don't need a huge KV cache — a sparse kv_indices with a few large index values pointing at a small real allocation should hit the overflow path in ~30 lines in tests/attention/test_mla_decode_kernel.py.

@Tracin
Author

Tracin commented Apr 22, 2026

@qsang-nv Thanks for the review! However, I don't see how large index values could point at a small real allocation; I suppose we need a real address for a large page_idx * stride_page.

@qsang-nv
Collaborator

@Tracin You are right, we do need a real address; however, it can be smaller than in the script you provided in the issue.

The overflow triggers as soon as page_idx × ckv_stride_page exceeds 2^32, so we just need page_idx to be just past that threshold — not at 1M.

Overflow threshold:

  • ckv_stride_page = ckv_cache.stride(0) is in elements, and for a contiguous [num_pages, page_size, kv_lora_rank] tensor equals page_size × kv_lora_rank = 32 × 512 = 16384.
  • Overflow happens when page_idx × 16384 ≥ 2^32.
  • 2^32 / 16384 = 262144, so the minimum page_idx that triggers it is 262144.

Allocation needed:

Tensors are contiguous, so to have page_idx = 262144 as a valid index, max_pages must be at least 262145. Using max_pages = 262170 to match the 26-page decode scenario in your debug script:

  • ckv_cache elements = 262170 × 32 × 512 = 4,295,393,280 ≈ 4.30 × 10^9 (bf16 → ×2 bytes) → ~8.59 GB
  • kpe_cache elements = 262170 × 32 × 64 = 536,924,160 (bf16 → ×2 bytes) → ~1.07 GB
  • Total ≈ 9.66 GB

Fits on most modern GPUs. Gate it with @pytest.mark.skipif(torch.cuda.mem_get_info()[0] < 12 * 1024**3, ...) to skip on small-memory runners. I think we can add this test to CI as-is first; if the memory footprint turns out to cause trouble on some runner, we can remove it, as I see there are already comments warning about the overflow.
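The threshold and footprint arithmetic above can be checked with a short script. Only the shapes, dtype width, and max_pages value come from the discussion; the variable names are illustrative.

```python
# Verify the overflow threshold and the memory footprint quoted above.
page_size, kv_lora_rank, kpe_dim = 32, 512, 64   # debug-scenario shapes
bytes_per_el = 2                                  # bf16

ckv_stride_page = page_size * kv_lora_rank        # 16384 elements per page
min_overflow_idx = (1 << 32) // ckv_stride_page   # smallest index that wraps
assert min_overflow_idx == 262144

max_pages = 262170                                # headroom for 26 decode pages
ckv_bytes = max_pages * page_size * kv_lora_rank * bytes_per_el
kpe_bytes = max_pages * page_size * kpe_dim * bytes_per_el
print(ckv_bytes / 1e9)                 # ~8.59 GB
print(kpe_bytes / 1e9)                 # ~1.07 GB
print((ckv_bytes + kpe_bytes) / 1e9)   # ~9.66 GB total
```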

Exercises the int64 widening added in a716f93 by running a 26-page
MLA decode with page indices starting at 262144 — the smallest index
that makes `indices[q] * ckv_stride_page` overflow uint32 for a
contiguous [*, 32, 512] cache (stride(0) = 16384). Compared against
a reference run with the same data and stride but non-overflowing
indices; pre-fix, the big-index run silently reads the wrong page
and produces garbage output with no crash.

Self-skips when free VRAM is below 12 GiB (the big cache alone is
~9.66 GiB).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
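The index layout the commit message describes can be sketched as follows. This is a hypothetical standalone fragment (in the real test these lists would become int32 torch tensors); only OVERFLOW_START = 262144, the 26-page count, and the stride come from the message.

```python
OVERFLOW_START = 262144   # smallest page index that wraps uint32 at stride 16384
NUM_PAGES = 26            # pages touched by the decode, per the commit message
STRIDE_PAGE = 16384       # ckv elements per page for a contiguous [*, 32, 512] cache

# Overflowing run: every page offset exceeds 2^32 elements pre-cast.
big_indices = list(range(OVERFLOW_START, OVERFLOW_START + NUM_PAGES))
# Reference run: the same pages addressed at small, non-overflowing indices.
ref_indices = list(range(NUM_PAGES))

assert all(i * STRIDE_PAGE >= 1 << 32 for i in big_indices)
assert all(i * STRIDE_PAGE < 1 << 32 for i in ref_indices)
```

Comparing the outputs of the two runs over identical data is what makes the pre-fix silent corruption observable.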
@Tracin
Author

Tracin commented Apr 23, 2026

@qsang-nv I see. Test is added!

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/attention/test_mla_decode_kernel.py (1)

571-578: Avoid zero-filling the entire ~9.6 GiB cache.

Only the wrapped low pages and overflow suffix are observed by this regression. Using empty() and zeroing the wrapped pages keeps the pre-fix failure deterministic without the full-cache memset cost.

Proposed refactor
-    ckv_big = torch.zeros(
+    ckv_big = torch.empty(
         total_num_pages, page_size, head_dim_ckv, device=device, dtype=dtype
     )
-    kpe_big = torch.zeros(
+    kpe_big = torch.empty(
         total_num_pages, page_size, head_dim_kpe, device=device, dtype=dtype
     )
+    ckv_big[:NUM_PAGES].zero_()
+    kpe_big[:NUM_PAGES].zero_()
     ckv_big[OVERFLOW_START:] = real_ckv
     kpe_big[OVERFLOW_START:] = real_kpe
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_mla_decode_kernel.py` around lines 571 - 578, The test
currently allocates ckv_big and kpe_big with torch.zeros(...) which zero-fills
the entire ~9.6 GiB cache; instead allocate with torch.empty(total_num_pages,
page_size, head_dim_ckv/ head_dim_kpe, device=device, dtype=dtype) for
ckv_big/kpe_big and only zero the observed regions: the wrapped low pages slice
and the overflow suffix (use OVERFLOW_START and the indices for the wrapped
pages derived from total_num_pages/page_size and wrap logic) then assign
real_ckv/real_kpe into ckv_big[OVERFLOW_START:] and kpe_big[OVERFLOW_START:];
this preserves deterministic pre-fix behavior while avoiding a full-cache
memset.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 090ae1b8-3ef6-4e48-befe-4b6b5ff513ed

📥 Commits

Reviewing files that changed from the base of the PR and between a716f93 and 35b56e1.

📒 Files selected for processing (1)
  • tests/attention/test_mla_decode_kernel.py

Comment on lines +512 to +522

@pytest.mark.parametrize("backend", ["fa2", "fa3"])
def test_mla_page_index_uint32_overflow_regression(backend):
    # Regression for the int64 widening in mla.cuh / mla_hopper.cuh
    # (`indices[q] * ckv_stride_page`). For a contiguous
    # [num_pages, page_size, head_dim_ckv] cache with page_size=32 and
    # head_dim_ckv=512, ckv_stride_page = 16384 elements. Any page index
    # >= 2^32 / 16384 = 262144 makes the multiplication overflow uint32 and
    # — pre-fix — silently wraps to the wrong page (no crash, wrong output).
    device = torch.device("cuda:0")
    if backend == "fa3" and not is_sm90a_supported(device):
        pytest.skip("fa3 backend requires SM90a")
Contributor


⚠️ Potential issue | 🟡 Minor

Add the existing architecture-error skip wrapper to this backend-parametrized test.

fa3 is gated, but fa2 or backend dispatch can still raise an unsupported-architecture error on some runners. Match the existing MLA decode test by wrapping this regression with @skip_on_gpu_arch_error.

Proposed fix

+@skip_on_gpu_arch_error
 @pytest.mark.parametrize("backend", ["fa2", "fa3"])
 def test_mla_page_index_uint32_overflow_regression(backend):

As per coding guidelines, tests/**/*.py: “Skip test execution on unsupported GPU architectures using flashinfer.utils check functions (is_sm90a_supported(), is_sm100a_supported(), etc.) or API methods like api_name.is_compute_capability_supported(cc)”.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_mla_decode_kernel.py` around lines 512 - 522, Add the
existing architecture-error skip wrapper to the parametrized test by decorating
test_mla_page_index_uint32_overflow_regression with `@skip_on_gpu_arch_error`
(imported from flashinfer.utils or the test utilities), so that
GPU-architecture-related exceptions raised for either backend are handled; keep
the existing `@pytest.mark.parametrize`("backend", ["fa2", "fa3"]) and the in-body
SM90a guard (is_sm90a_supported) for fa3, but place the `@skip_on_gpu_arch_error`
decorator directly above the test function definition to match other MLA decode
tests.

@qsang-nv
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !586 has been created, and the CI pipeline #49261873 is currently running. I'll report back once the pipeline job completes.

